# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections import OrderedDict

from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core


class TestRollSPMDRule(unittest.TestCase):
    def setUp(self):
        x_shape = [16, 16, 16]
        out_shape = [16, 16, 16]
        process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        out_grad_tensor_dist_attr = TensorDistAttr()
        out_grad_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        out_grad_tensor_dist_attr.process_mesh = process_mesh
        self.out_grad_dist_tensor_spec = DistTensorSpec(
            out_shape, out_grad_tensor_dist_attr
        )

        self.rule = core.get_phi_spmd_rule("roll")
        self.attrs = OrderedDict()
        self.attrs['shifts'] = [1]
        self.attrs['axis'] = []

    def test_roll_forward(self):
        # axis = [], shifts = [1]
        # [0, 1, -1] --> [-1, -1, -1], [-1, -1, -1]
        self.attrs['axis'] = []
        self.attrs['shifts'] = [1]
        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs['shifts'],
            self.attrs['axis'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 1)

        self.assertEqual(
            inferred_input_dist_attrs[0].dims_mapping, [-1, -1, -1]
        )
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [-1, -1, -1]
        )

        # axis = [0, 2], shifts = [1, 2]
        # [0, 1, -1] --> [-1, 1, -1], [-1, 1, -1]
        self.attrs['axis'] = [0, 2]
        self.attrs['shifts'] = [1, 2]
        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs['shifts'],
            self.attrs['axis'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 1)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [-1, 1, -1]
        )

    def test_roll_backward(self):
        # axis = [], shifts = [1]
        # [-1, -1, -1], [0, 1, -1] --> [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
        self.attrs['axis'] = []
        self.attrs['shifts'] = [1]
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
        self.out_grad_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            self.out_grad_dist_tensor_spec,
            self.attrs['shifts'],
            self.attrs['axis'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(len(inferred_input_dist_attrs), 2)
        self.assertEqual(len(inferred_output_dist_attrs), 1)
        self.assertEqual(
            inferred_input_dist_attrs[0].dims_mapping, [-1, -1, -1]
        )
        self.assertEqual(
            inferred_input_dist_attrs[1].dims_mapping, [-1, -1, -1]
        )
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [-1, -1, -1]
        )

        # axis = [0, 2], shifts = [1, 2]
        # [-1, -1, -1], [0, 1, -1] --> [-1, 1, -1], [-1, 1, -1], [-1, 1, -1]
        self.attrs['axis'] = [0, 2]
        self.attrs['shifts'] = [1, 2]
        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
        self.out_grad_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            self.out_grad_dist_tensor_spec,
            self.attrs['shifts'],
            self.attrs['axis'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(len(inferred_input_dist_attrs), 2)
        self.assertEqual(len(inferred_output_dist_attrs), 1)
        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [-1, 1, -1])
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [-1, 1, -1]
        )


if __name__ == "__main__":
    unittest.main()
